# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Dict

from torch import nn
from torch.nn import functional as F

from efg.data.structures.shape_spec import ShapeSpec
from efg.modeling.common import weight_init
from efg.modeling.common.blocks import Conv2d

from pixel_decoder.fpn import build_pixel_decoder
from transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder


class PerPixelBaselineHead(nn.Module):
    """
    Semantic segmentation head: a pixel decoder followed by a 1x1 convolution that
    predicts per-pixel class logits, trained with a cross-entropy loss.
    """

    # Checkpoint format version. State dicts whose metadata carries no version, or a
    # version below 2, are converted in `_load_from_state_dict`.
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Rename keys of old-format checkpoints in place so they match the current layout.

        For checkpoints with version < 2 (or no version), every key that contains
        "sem_seg_head" and does not start with ``prefix + "predictor"`` gets
        ``"pixel_decoder."`` inserted after ``prefix``. ``state_dict`` is mutated in place.

        NOTE(review): this override does not call the ``nn.Module`` implementation, so it
        only rewrites keys. That appears to rely on all parameters living in child
        modules (``pixel_decoder``, ``predictor``) -- confirm before registering
        parameters or buffers directly on this module.
        """
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            # Do not warn if train from scratch
            scratch = True
            # Iterate over a snapshot of the keys because the dict is mutated in the loop.
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    def __init__(
        self,
        config,
        input_shape: Dict[str, ShapeSpec],
    ):
        """
        NOTE: this interface is experimental.
        Args:
            config: configuration object. This constructor reads
                ``MODEL.SEM_SEG_HEAD.IN_FEATURES``, ``IGNORE_VALUE``, ``NUM_CLASSES`` and
                ``LOSS_WEIGHT``; the config is also passed to ``build_pixel_decoder``.
            input_shape: shapes (channels and stride) of the input features, keyed by
                feature name. Entries not listed in ``IN_FEATURES`` are dropped.
        """
        super().__init__()

        head_cfg = config.MODEL.SEM_SEG_HEAD
        input_shape = {k: v for k, v in input_shape.items() if k in head_cfg.IN_FEATURES}
        ignore_value = head_cfg.IGNORE_VALUE
        num_classes = head_cfg.NUM_CLASSES
        pixel_decoder = build_pixel_decoder(config, input_shape)
        loss_weight = head_cfg.LOSS_WEIGHT

        # Feature names ordered from the smallest stride to the largest.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [name for name, _ in input_shape]

        self.ignore_value = ignore_value
        # Factor by which predictions are upsampled in `forward` and `losses`.
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = Conv2d(self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0)
        weight_init.c2_msra_fill(self.predictor)

    def forward(self, features, targets=None):
        """
        Args:
            features: backbone features, forwarded to the pixel decoder.
            targets: ground-truth labels used for the loss; only read in training mode.

        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (logits upsampled by ``common_stride``, {})
        """
        x = self.layers(features)
        if self.training:
            return None, self.losses(x, targets)
        else:
            x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False)
            return x, {}

    def layers(self, features):
        """Run the pixel decoder and apply the 1x1 predictor to its first output."""
        # forward_features returns three values; only the first one is used here.
        x, _, _ = self.pixel_decoder.forward_features(features)
        x = self.predictor(x)
        return x

    def losses(self, predictions, targets):
        """
        Upsample ``predictions`` by ``common_stride`` and compute the weighted
        cross-entropy loss against ``targets``, ignoring ``ignore_value`` labels.

        Returns:
            dict with the single key "loss_sem_seg".
        """
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        predictions = F.interpolate(predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False)
        loss = F.cross_entropy(predictions, targets, reduction="mean", ignore_index=self.ignore_value)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses


class PerPixelBaselinePlusHead(PerPixelBaselineHead):
    """
    Variant of ``PerPixelBaselineHead`` whose predictor is a
    ``StandardTransformerDecoder`` (built with ``mask_classification=False``)
    instead of a 1x1 convolution, with optional deep supervision.
    """

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Rename keys of old-format checkpoints in place so they match the current layout.

        For checkpoints with version < 2 (or no version), every key that contains
        "sem_seg_head" and does not start with ``prefix + "predictor"`` gets
        ``"pixel_decoder."`` inserted after ``prefix``. Each rename is logged at debug
        level. ``state_dict`` is mutated in place.
        """
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            # Iterate over a snapshot of the keys because the dict is mutated in the loop.
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    def __init__(self, config, input_shape: Dict[str, ShapeSpec]):
        """
        NOTE: this interface is experimental.
        Args:
            config: configuration object. In addition to what the parent constructor
                reads, this constructor reads ``MODEL.SEM_SEG_HEAD.CONVS_DIM``,
                ``MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE`` and
                ``MODEL.MASK_FORMER.DEEP_SUPERVISION``.
            input_shape: shapes (channels and stride) of the input features, keyed by
                feature name. Entries not listed in
                ``MODEL.SEM_SEG_HEAD.IN_FEATURES`` are dropped.
        """
        input_shape = {k: v for k, v in input_shape.items() if k in config.MODEL.SEM_SEG_HEAD.IN_FEATURES}

        # The parent constructor takes (config, input_shape) and itself builds the pixel
        # decoder and stores ignore_value / loss_weight, so nothing else is passed here.
        super().__init__(config, input_shape)

        transformer_in_feature = config.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
        if transformer_in_feature == "transformer_encoder":
            in_channels = config.MODEL.SEM_SEG_HEAD.CONVS_DIM
        else:
            in_channels = input_shape[transformer_in_feature].channels
        transformer_predictor = StandardTransformerDecoder(config, in_channels, mask_classification=False)
        deep_supervision = config.MODEL.MASK_FORMER.DEEP_SUPERVISION

        # Drop the 1x1 conv predictor created by the parent and use the transformer
        # decoder in its place.
        del self.predictor

        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature
        self.deep_supervision = deep_supervision

    def forward(self, features, targets=None):
        """
        Args:
            features: backbone features, forwarded to the pixel decoder (and indexed by
                ``transformer_in_feature`` unless that is "transformer_encoder").
            targets: ground-truth labels used for the loss; only read in training mode.

        Returns:
            In training, returns (None, dict of losses). With deep supervision the dict
            also holds one "loss_sem_seg_{i}" entry per auxiliary output.
            In inference, returns (logits upsampled by ``common_stride``, {})
        """
        x, aux_outputs = self.layers(features)
        if self.training:
            if self.deep_supervision:
                losses = self.losses(x, targets)
                for i, aux_output in enumerate(aux_outputs):
                    losses["loss_sem_seg" + f"_{i}"] = self.losses(aux_output["pred_masks"], targets)["loss_sem_seg"]
                return None, losses
            else:
                return None, self.losses(x, targets)
        else:
            x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False)
            return x, {}

    def layers(self, features):
        """
        Run the pixel decoder and the transformer predictor.

        Returns:
            (pred_masks, aux_outputs); ``aux_outputs`` is None when deep supervision
            is disabled.
        """
        mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert transformer_encoder_features is not None, "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        if self.deep_supervision:
            return predictions["pred_masks"], predictions["aux_outputs"]
        else:
            return predictions["pred_masks"], None
